import numpy as np
from matplotlib import pyplot as plt
import matplotlib
from matplotlib.patches import Polygon
from matplotlib.collections import PatchCollection
import matplotlib.gridspec as gridspec

#####################################################################################################################################################################
#   Code for producing Figure 3--Figure Supplement 7.
#   The following data files are used in this figure (the script reads them from the kl_divergence_contours/ directory):
#   -- diffusion_kl_divergence.txt, KL-divergence between the free head distributions produced by the cosine potential and
#            the quartic potential with cubic coefficient h_3 and quartic coefficient h_4
#   -- diffusion_contour_alt_potential_quadratic.txt, x-z projection of the diffusion contour corresponding to the quadratic potential (h_3=h_4=0)
#   -- diffusion_contour_alt_potential_2_5.txt, x-z projection of the diffusion contour corresponding to the quartic potential with h_3=2 and h_4=5
#   -- diffusion_contour_alt_potential_-2_-2.txt, x-z projection of the diffusion contour corresponding to the quartic potential with h_3=-2 and h_4=-2
#
#####################################################################################################################################################################



# ---- Shared plotting constants ------------------------------------------
maxLevel = 4.0

# 14 contour boundaries for the probability panels.  The lower bound is a
# tiny negative number rather than exactly zero (presumably so that cells
# equal to zero are never left outside the lowest band -- TODO confirm).
levels = np.linspace(
    start=-0.00000000001 * 10 ** (-4),
    stop=6.75 * 10 ** (-4),
    num=14,
)

# 14 contour boundaries for the log10(KL divergence) panel.
levelsKL = np.linspace(start=-6.5, stop=-0.0, num=14)

print(levels)

# Font sizes: axis labels, and in-panel annotations.
axisLabelSize = 18
pltLabelSize = 15


# ---- Backend, figure and grid layout ------------------------------------
# Request the PostScript backend through rcParams.
# NOTE(review): 'ps' is a non-interactive backend, so the plt.show() call at
# the end of the script may not open a window -- confirm this is intended.
params = dict(backend='ps')
plt.rcParams.update(params)

# One figure, sized in inches (width 17*10/9/2 is about 9.44, height 7.5).
fig = plt.figure(figsize=(17*10/9/2, 7.5))

# Outer 8 x 10 grid that every sub-layout below is carved out of.
outer = gridspec.GridSpec(nrows=8, ncols=10, wspace=2.8, hspace=1.2)

# 2 x 2 block for the four data panels: all eight rows, and every column
# except the first and the last (those are kept for the colorbars).
theory = gridspec.GridSpecFromSubplotSpec(
    nrows=2,
    ncols=2,
    subplot_spec=outer[:8, 1:-1],
    wspace=0.35,
)

# Colorbar slot in the last column, rows 1 through 6.
cbar = gridspec.GridSpecFromSubplotSpec(
    nrows=1,
    ncols=1,
    subplot_spec=outer[1:7, -1],
)

# Colorbar slot in the first column, rows 0 through 3.
cbar2 = gridspec.GridSpecFromSubplotSpec(
    nrows=1,
    ncols=1,
    subplot_spec=outer[0:4, 0],
)



# ---- Panel A: log10 of the KL divergence over the (h_3, h_4) plane ------
ax = fig.add_subplot(theory[0, 0])

# The table has three columns, taken in order as h_3, h_4 and the KL value.
kl_table = np.loadtxt('kl_divergence_contours/diffusion_kl_divergence.txt')
h3_col, h4_col, kl_col = np.transpose(kl_table)

# Rebuild the 2-D grid from the flattened table.  Each count is the number
# of rows sharing the value found in the second row of that column.
# (Assumes the rows form a full rectangular grid with h_3 as the slowly
# varying column -- TODO confirm against the data file.)
rows_per_h3 = np.count_nonzero(h3_col == h3_col[1])
rows_per_h4 = np.count_nonzero(h4_col == h4_col[1])
h3_axis = h3_col[::rows_per_h3]
h4_axis = h4_col[:int(len(h4_col) / rows_per_h4)]
kl_grid = np.transpose(np.reshape(kl_col, [rows_per_h4, rows_per_h3]))

ax.contourf(h3_axis, h4_axis, np.log10(kl_grid), levelsKL, cmap=plt.cm.bone)

ax.set_xlabel(r'$h_3$', size=axisLabelSize, labelpad=0)
ax.set_ylabel(r'$h_4$', size=axisLabelSize, labelpad=-2)

# Panel letter in the top-left corner (axes-fraction coordinates).
ax.text(
    0.01, 0.99, 'A',
    size=pltLabelSize,
    transform=ax.transAxes,
    horizontalalignment='left',
    verticalalignment='top',
)

# Stars at the three (h_3, h_4) pairs shown in panels B, C and D, each
# tagged with its panel letter just up and to the right (data coordinates).
ax.plot([0, 2, -2], [0, 5, -2], 'k*')
for tag_h3, tag_h4, tag in ((0.1, 0.1, 'B'), (2.1, 5.1, 'C'), (-1.9, -1.9, 'D')):
    ax.text(tag_h3, tag_h4, tag, size=pltLabelSize-3)



# ---- Panels B, C, D: x-z probability contours for three potentials ------
# The three panels differ only in grid cell, data file, panel letter and
# caption, so they are drawn by one loop instead of three pasted copies.

def _load_xz_grid(path):
    """Read a three-column text table and rebuild it as axes plus a 2-D grid.

    The file is read with ``np.loadtxt`` and its columns are taken, in
    order, as z, x and probability.

    Assumes the rows form a full rectangular grid with z as the slowly
    varying column and at least two rows -- TODO confirm against the data.

    Parameters
    ----------
    path : str
        Location of the whitespace-separated table.

    Returns
    -------
    tuple of numpy.ndarray
        ``(z_axis, x_axis, prob_grid)``, where ``prob_grid`` is the
        transposed reshape of the probability column, i.e. the orientation
        passed to ``contourf(z_axis, x_axis, prob_grid)`` below.
    """
    z_col, x_col, prob_col = np.transpose(np.loadtxt(path))

    # Number of rows sharing the value found in the second row of a column.
    rows_per_z = np.count_nonzero(z_col == z_col[1])
    rows_per_x = np.count_nonzero(x_col == x_col[1])

    z_axis = z_col[::rows_per_z]                 # one entry per run of equal z
    x_axis = x_col[:len(x_col) // rows_per_x]    # first sweep through x
    prob_grid = np.transpose(np.reshape(prob_col, [rows_per_x, rows_per_z]))
    return z_axis, x_axis, prob_grid


# (grid cell, panel letter, data file, caption) for each panel, in draw order.
_XZ_PANELS = (
    (theory[0, 1], 'B',
     'kl_divergence_contours/diffusion_contour_alt_potential_quadratic.txt',
     r'$h_3 = 0, h_4 = 0$'),
    (theory[1, 0], 'C',
     'kl_divergence_contours/diffusion_contour_alt_potential_2_5.txt',
     r'$h_3 = 2, h_4 = 5$'),
    (theory[1, 1], 'D',
     'kl_divergence_contours/diffusion_contour_alt_potential_-2_-2.txt',
     r'$h_3 = -2, h_4 = -2$'),
)

for panel_spec, panel_letter, data_path, potential_label in _XZ_PANELS:
    ax = fig.add_subplot(panel_spec)

    z, x, probxz = _load_xz_grid(data_path)

    # Filled bands with black contour lines on top, using the shared levels.
    colorBarAxs = ax.contourf(z, x, probxz, levels, cmap=plt.cm.YlOrBr)
    ax.contour(z, x, probxz, levels, colors='k')

    ax.set_xlabel(r'$z$', size=axisLabelSize, labelpad=0)
    ax.set_ylabel(r'$x$', size=axisLabelSize, labelpad=-2)

    # Corner annotations, all in axes-fraction coordinates:
    # quantity name (top right), panel letter (top left), caption (bottom left).
    ax.text(0.99, 0.99, r'$\mathcal{P}(x,z)$', size=pltLabelSize,
            transform=ax.transAxes,
            horizontalalignment='right', verticalalignment='top')
    ax.text(0.01, 0.99, panel_letter, size=pltLabelSize,
            transform=ax.transAxes,
            horizontalalignment='left', verticalalignment='top')
    ax.text(0.01, 0.01, potential_label, size=pltLabelSize,
            transform=ax.transAxes,
            horizontalalignment='left', verticalalignment='bottom')





# ---- Colorbar for the probability panels (right-hand slot) --------------
ax = fig.add_subplot(cbar[:, :])

# `levels` holds raw probabilities; the bar is labelled in units of 1e-4,
# so the top of the scale is the largest level multiplied by 10**4.
scaled_top = max(levels)*10**4
prob_ticks = range(0, int(scaled_top)+1)

# The bar is driven by a stand-alone mappable built from a norm and a
# colormap, not by the object returned from contourf.
prob_mappable = plt.cm.ScalarMappable(
    norm=matplotlib.colors.Normalize(0, scaled_top),
    cmap=plt.cm.YlOrBr,
)
prob_mappable.set_array([])

cbar = fig.colorbar(
    prob_mappable,
    ticks=prob_ticks,
    cax=ax,
    label=r'Probability ($\times 10^{-4}$)',
)
cbar.ax.tick_params(labelsize=18)
cbar.ax.set_yticklabels(prob_ticks)

# Ticks and label on the right-hand side of the bar.
prob_bar_axis = cbar.ax.yaxis
prob_bar_axis.set_ticks_position('right')
prob_bar_axis.set_label_position('right')
prob_bar_axis.label.set_size(20)



# ---- Colorbar for the KL-divergence panel (left-hand slot) --------------
ax = fig.add_subplot(cbar2[:, :])

kl_ticks = range(-6, 1)

# Stand-alone mappable spanning -6.5 to 0 with the `bone` colormap; the
# bar is built from it, not from the object returned by contourf.
kl_mappable = plt.cm.ScalarMappable(
    norm=matplotlib.colors.Normalize(-6.5, 0.),
    cmap=plt.cm.bone,
)
kl_mappable.set_array([])

cbar = fig.colorbar(
    kl_mappable,
    ticks=kl_ticks,
    cax=ax,
    label=r'$\log_{10} (D_{KL})$',
)
cbar.ax.tick_params(labelsize=18)
cbar.ax.set_yticklabels(kl_ticks)

# Ticks and label on the left-hand side of the bar.
kl_bar_axis = cbar.ax.yaxis
kl_bar_axis.set_ticks_position('left')
kl_bar_axis.set_label_position('left')
kl_bar_axis.label.set_size(20)






# ---- Output -------------------------------------------------------------
# Write the current figure to a PDF in the working directory, cropped
# tightly around the drawn content, then hand over to the display call.
output_pdf = 'diffusion_figure_alt_potentials.pdf'
plt.savefig(output_pdf, bbox_inches='tight')

plt.show()


